"""
自动并行配置搜索主要入口
"""
import os

import torch
import logging
import multiprocessing

from mindspeed.auto_settings.config.post_info import PostInfo
from mindspeed.auto_settings.config.model_config import set_model_config, get_model_config, \
    update_model_config
from mindspeed.auto_settings.config.system_config import set_system_config, get_system_config
from mindspeed.auto_settings.profile.profiler import Profiler
from mindspeed.auto_settings.module.searcher import WhiteSearcher, BlackSearcher, MixedSearcher
from mindspeed.auto_settings.search_space import SearchSpace
from mindspeed.auto_settings.utils.file_utils import restricted_read
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.mindspeed_adaptor.mindspeed_executor import ExecutorFlag
from mindspeed.auto_settings.module.search.auto_tuning_disttrain import ModelSpaces, SingleModel


if os.environ.get('AUTO_SETTING_SEARCH_CPU_NUM') is None:
    AUTO_SETTING_SEARCH_CPU_NUM = multiprocessing.cpu_count()
elif int(os.environ.get('AUTO_SETTING_SEARCH_CPU_NUM')) > 1:
    AUTO_SETTING_SEARCH_CPU_NUM = int(os.environ.get('AUTO_SETTING_SEARCH_CPU_NUM'))
else:
    print("Default: set AUTO_SETTING_SEARCH_CPU_NUM = 1")
    AUTO_SETTING_SEARCH_CPU_NUM = 4
 
 
_AUTO_SETTING_FILENAME = "/auto_settings.json"
 
class AutoSettings(object):
    """Main entry point for automatic parallel-configuration search.

    Orchestrates: config initialization, hardware profiling, process-group
    setup, and dispatch to the appropriate search flow (white / black / mixed,
    or the dedicated dist-train multimodal flow).
    """

    def __init__(self):
        # Search-space builder, hardware profiler, and module logger.
        self.search_spaces = SearchSpace()
        self.profiler = Profiler()
        self.logger = get_logger("AutoSettings")

    def init(self, args):
        """Initialize configs, hardware info, and the global process group.

        Args:
            args: parsed command-line arguments (argparse Namespace).
        """
        self._init_configs(args)
        self._init_hardware(args)
        self._init_global_group()

    def _init_hardware(self, args):
        """Profile local hardware and merge the result into the global configs.

        Runs the profiler in PARSE_ARGS mode, then reads the generated
        post-info file back from the work dir to update both the system
        config and the model config.
        """
        self.profiler.run(PostInfo.FILENAME, None, ExecutorFlag.PARSE_ARGS)
        post_info = restricted_read(
            os.path.join(str(get_system_config().work_dir), PostInfo.FILENAME))
        get_system_config().load_settings(post_info)
        update_model_config(post_info.model_config)

    def _init_configs(self, args):
        """Populate the global system and model configs from *args*."""
        set_system_config(args)
        set_model_config(args)

    def _init_global_group(self):
        """Create the GLOO process group used for cross-node coordination."""
        sys_config = get_system_config()
        torch.distributed.init_process_group(
            backend=torch.distributed.Backend.GLOO,
            rank=sys_config.node_rank,
            world_size=sys_config.nnodes
        )

    def _get_searcher(self, search_type):
        """Return the searcher matching *search_type*.

        "white" / "black" / "mixed" map to their respective searchers; the
        mixed searcher keeps the top-5 white-box candidates. Any other value
        falls back to a plain MixedSearcher.
        """
        if search_type == "white":
            return WhiteSearcher()
        if search_type == "black":
            return BlackSearcher()
        if search_type == "mixed":
            searcher = MixedSearcher()
            searcher.set_white_topk(5)
            return searcher
        return MixedSearcher()

    def search(self, args):
        """Build the search space and run the configured searcher.

        Returns:
            The top-3 final configurations produced by the searcher.
        """
        self.logger.info("model config is that:\n%s", str(get_model_config()))
        search_type = args.auto_settings_type
        searcher = self._get_searcher(search_type)
        # Build the candidate search space, then let the searcher rank it.
        search_configs = self.search_spaces.build_search_spaces()
        final_configs = searcher.search(configs=search_configs, topk=3)

        return final_configs

    def auto_settings(self, args):
        """Top-level entry: initialize, then dispatch to the proper flow.

        Returns:
            The list of recommended configurations on the master node;
            ``None`` on slave nodes and for the dist-train white flow.
        """
        self.init(args)
        # dist_train + white-box search uses the dedicated multimodal flow.
        if hasattr(args, "dist_train") and args.dist_train and args.auto_settings_type == "white":
            self.auto_tuning_disttrain(args)
            return
        # Non-master nodes only serve profiling requests from the master.
        if get_system_config().node_rank != 0:
            self.profiler.run_on_slaves(args)
            return
        final_configs = self.search(args)
        self.logger.info("<==========Final config generated==========>")
        self.logger.info("The recommended configurations are:")
        for i, final_cfg in enumerate(final_configs):
            if final_cfg:
                self.logger.info("<==========Top #%s config==========>", str(i))
                # Emit the full config at whichever level is effective so it
                # is visible both in normal and debug runs.
                if self.logger.getEffectiveLevel() == logging.DEBUG:
                    self.logger.debug("\n%s", str(final_cfg))
                else:
                    self.logger.info("\n%s", str(final_cfg))
        self.logger.info("<==========Start training==========>")
        return final_configs

    def auto_tuning_disttrain(self, args):
        """Search flow for dist-train / multimodal models.

        Splits the multimodal model into single models, builds per-model
        memory and performance models on the master node, then runs the
        joint space search across all sub-models.
        """
        model_spaces = ModelSpaces()
        model_spaces.sources_setting.init_settings(args)
        logger = get_logger("main")
        # Split the multimodal workload into single models.
        if hasattr(args, "mm_data") and "sora" in args.mm_data:
            # Special adaptation for SoRA-style models.
            try:
                from mindspeed.auto_settings.module.search.search_engine_sora import SpaceSearch
                model_spaces.space_search_model = SpaceSearch()
                model_spaces.sora_config()
            except Exception as e:
                # Best-effort: continue with whatever was configured so far.
                logger.warning("sora setup failed: %s", e)
        elif hasattr(args, "dist_train") and args.dist_train:
            model_spaces.model_type = "vit+gpt"
            try:
                from mindspeed.auto_settings.module.search.multimodal_search_surrogate import MultimodalSpaceSearchSurrogate
                model_spaces.space_search_model = MultimodalSpaceSearchSurrogate()
                model_spaces.dist_train_config()
            except Exception as e:
                logger.warning("dist_train setup failed: %s", e)
        else:
            try:
                from mindspeed.auto_settings.module.search.search_engine import SpaceSearch
                model_spaces.space_search_model = SpaceSearch()
                model_spaces.model_list = [SingleModel(model_spaces.time_recorder)]
                model_spaces.model_list[0].model_settings = model_spaces.sources_setting
                model_spaces.model_list[0].model_config.sub_work_dir = model_spaces.sources_setting.work_dir
            except Exception as e:
                logger.warning("single-model setup failed: %s", e)

        for model in model_spaces.model_list:
            model.parse_args(model_spaces.model_type)
            if model.model_settings.node_rank != 0:
                # Slave nodes only wait on the master.
                # NOTE(review): this returns after waiting on the FIRST
                # non-master model, skipping any remaining models — confirm
                # that slaves are meant to exit here.
                model.model_settings.executor.wait(
                    model_config=model.model_config,
                    gloo_group=model.model_settings.gloo_group
                    )
                return
            model._memory_model()
            model._performance_model()
            model.model_settings.gloo_group = None
        model_spaces.space_search_model._space_search(
            models=model_spaces.model_list, cpu_num=AUTO_SETTING_SEARCH_CPU_NUM)